{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fe5662bb",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# State Saving and Loading\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/brainpy/blob/master/docs_version2/tutorial_toolbox/state_saving_and_loading.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/brainpy/blob/master/docs_version2/tutorial_toolbox/state_saving_and_loading.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ba75189",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Being able to save and load the variables of a model is essential in brain dynamics programming. In this tutorial we describe how to save/load the variables in a  model. "
   ]
  },
  {
   "cell_type": "code",
   "id": "eff1932c",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "ExecuteTime": {
     "end_time": "2025-10-06T05:17:17.088269Z",
     "start_time": "2025-10-06T05:17:12.568513Z"
    }
   },
   "source": [
    "import numpy as np\n",
    "\n",
    "import brainpy as bp\n",
    "import brainpy.math as bm\n",
    "\n",
    "bp.math.set_platform('cpu')"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "id": "4ef65b38",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Saving and loading variables"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a88696576c7242c6",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "State saving and loading in BrainPy are managed by a **local** function and a **global** function. \n",
    "\n",
    "The **local function** is to save or load states in the current node. Particularly, ``save_state()`` and ``load_state()`` are local functions for saving and loading states. \n",
    "\n",
    "The **global function** is to save or load all states in the current and children nodes. Particularly, ``brainpy.save_state()`` and ``brainpy.load_state()`` are global functions for saving and loading states. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01b7ac95",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Here’s a simple example:"
   ]
  },
  {
   "cell_type": "code",
   "id": "bc2cab20",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "ExecuteTime": {
     "end_time": "2025-10-06T05:17:17.097455Z",
     "start_time": "2025-10-06T05:17:17.093277Z"
    }
   },
   "source": [
    "class SNN(bp.DynamicalSystem):\n",
    "  def __init__(self):\n",
    "    super().__init__()\n",
    "    self.var = bm.Variable(bm.zeros(1))\n",
    "    self.l1 = bp.dnn.Dense(28 * 28, 10, b_initializer=None)\n",
    "    self.l2 = bp.dyn.Lif(10, V_rest=0., V_reset=0., V_th=1., tau=2.0, spk_fun=bm.surrogate.Arctan())\n",
    "\n",
    "  def update(self, x):\n",
    "    return x >> self.l1 >> self.l2"
   ],
   "outputs": [],
   "execution_count": 2
  },
  {
   "cell_type": "code",
   "id": "59a6abf6a8eabaa9",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-10-06T05:17:17.551790Z",
     "start_time": "2025-10-06T05:17:17.100466Z"
    }
   },
   "source": [
    "net = SNN()"
   ],
   "outputs": [],
   "execution_count": 3
  },
  {
   "cell_type": "markdown",
   "id": "bad4acc7d799b60d",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "##### State saving\n",
    "\n",
    "To extract the local variables in the ``net``:"
   ]
  },
  {
   "cell_type": "code",
   "id": "5eb9d839e47cf417",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-10-06T05:17:17.564576Z",
     "start_time": "2025-10-06T05:17:17.558312Z"
    }
   },
   "source": [
    "net.save_state()"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'SNN0.var': Array([0.], dtype=float32)}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 4
  },
  {
   "cell_type": "markdown",
   "id": "18709a9b365bf34f",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "To extract all variable under the ``net`` (including the local variables in the sub-nodes):"
   ]
  },
  {
   "cell_type": "code",
   "id": "a5e0fc0f7f424718",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-10-06T05:17:17.578082Z",
     "start_time": "2025-10-06T05:17:17.571977Z"
    }
   },
   "source": [
    "bp.save_state(net)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'SNN0': {'SNN0.var': Array([0.], dtype=float32)},\n",
       " 'Dense0': {},\n",
       " 'Lif0': {'Lif0.V': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n",
       "  'Lif0.spike': Array([False, False, False, False, False, False, False, False, False,\n",
       "         False], dtype=bool)}}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 5
  },
  {
   "cell_type": "markdown",
   "id": "c76da75caf11181d",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "If we want to save states of a model onto the disk, we can use ``brainpy.checkpoints.save_pytree``. "
   ]
  },
  {
   "cell_type": "code",
   "id": "1b3cf2ec8272839f",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-10-06T05:17:17.604778Z",
     "start_time": "2025-10-06T05:17:17.596541Z"
    }
   },
   "source": [
    "bp.checkpoints.save_pytree('a.bp', net.state_dict())"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving checkpoint into a.bp\n"
     ]
    }
   ],
   "execution_count": 6
  },
  {
   "cell_type": "markdown",
   "id": "e63dbe6b7a171cea",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "##### State loading\n",
    "\n",
    "To retrieve the saved states in the disk, one can use ``brainpy.checkpoints.load_pytree``. "
   ]
  },
  {
   "cell_type": "code",
   "id": "2cdc6d82d53317e7",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-10-06T05:17:17.654474Z",
     "start_time": "2025-10-06T05:17:17.609953Z"
    }
   },
   "source": [
    "states = bp.checkpoints.load_pytree('a.bp')"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading checkpoint from a.bp\n"
     ]
    }
   ],
   "execution_count": 7
  },
  {
   "cell_type": "code",
   "id": "4d18c9fba2983e69",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-10-06T05:17:17.664113Z",
     "start_time": "2025-10-06T05:17:17.659483Z"
    }
   },
   "source": [
    "states"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'SNN0': {'SNN0.var': array([0.], dtype=float32)},\n",
       " 'Dense0': {},\n",
       " 'Lif0': {'Lif0.V': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n",
       "  'Lif0.spike': array([False, False, False, False, False, False, False, False, False,\n",
       "         False])},\n",
       " 'ExponentialEuler0': {}}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 8
  },
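  {
   "cell_type": "markdown",
   "id": "c3f1a9d27b6e4a10",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "As the output above shows, the loaded checkpoint is a plain nested dictionary: node names map to dictionaries of variable names and NumPy arrays. The following is a minimal sketch of how such a dictionary might be inspected. The ``states`` dictionary here is a hand-built stand-in that mirrors part of the output above, and ``summarize`` is a hypothetical helper, not a BrainPy function.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hand-built stand-in for the loaded checkpoint (not read from disk).\n",
    "states = {\n",
    "  'SNN0': {'SNN0.var': np.zeros(1, dtype=np.float32)},\n",
    "  'Dense0': {},\n",
    "  'Lif0': {'Lif0.V': np.zeros(10, dtype=np.float32),\n",
    "           'Lif0.spike': np.zeros(10, dtype=bool)},\n",
    "}\n",
    "\n",
    "\n",
    "def summarize(nested):\n",
    "  # Map each variable name to its (shape, dtype name).\n",
    "  summary = {}\n",
    "  for node_name, node_states in nested.items():\n",
    "    for var_name, value in node_states.items():\n",
    "      arr = np.asarray(value)\n",
    "      summary[var_name] = (arr.shape, arr.dtype.name)\n",
    "  return summary\n",
    "\n",
    "\n",
    "summary = summarize(states)\n",
    "for name, (shape, dtype) in summary.items():\n",
    "  print(name, shape, dtype)\n",
    "```\n",
    "\n",
    "The same helper could be applied to the dictionary returned by ``brainpy.checkpoints.load_pytree``, assuming it has the two-level layout shown above."
   ]
  },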
  {
   "cell_type": "markdown",
   "id": "29a23a960953148a",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "After loading the model onto the memory, we can assign the loaded states to the corresponding variable by using ``load_state_dict()`` function."
   ]
  },
  {
   "cell_type": "code",
   "id": "a585a32ef51654b",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-10-06T05:17:17.674683Z",
     "start_time": "2025-10-06T05:17:17.668604Z"
    }
   },
   "source": [
    "bp.load_state(net, states)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "StateLoadResult(missing_keys=[], unexpected_keys=[])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 9
  },
  {
   "cell_type": "markdown",
   "id": "1aeba1f9",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "- ``brainpy.checkpoints.save_pytree(filename: str, target: PyTree, overwrite: bool = True, async_manager: Optional[AsyncManager] = None, verbose: bool = True)``  function requires you to provide a `filename` which is the path where checkpoint files will be stored. You also need to supply a `target`, which is a state dict object. An optional `overwrite` argument allows you to decide whether to overwrite existing checkpoint files \n",
    "if a checkpoint for the current step or a later one already exists. If you provide an `async_manager`, the save operation will be non-blocking on the main thread, but note that this is only suitable for a single host. However, any ongoing save will still prevent \n",
    "new saves to ensure overwrite logic remains correct. Finally, you can set the `verbose` argument to specify if you want to receive printed information about the operation.\n",
    "\n",
    "- ``brainpy.checkpoints.load_pytree(filename: str, parallel: bool = True)`` function allows you to restore data from a given checkpoint file or a directory containing multiple checkpoints, which you specify with the `filename` argument. If you set the `parallel` argument to true, the function will attempt to load seekable checkpoints simultaneously for quicker results. When executed, the function returns the restored target from the checkpoint file. If no step is specified and there are no checkpoint files available, the function simply returns the input `target` without changes. If you specify a file path that doesn't exist, the function will also return the original `target`. This behavior mirrors the scenario where a directory path is given, but the directory hasn't been created yet.\n",
    "\n",
    "- ``brainpy.save_state(target)`` function retrieves the entire state of the ``target`` module and returns it as a dictionary. \n",
    "\n",
    "- ``brainpy.load_state(target, state_dict)`` function is used to import parameters and buffers from a provided `state_dict` into the current module and all its child modules. You need to provide the function with a `state_dict`, which is a dictionary containing the desired parameters and persistent buffers to be loaded. hen executed, the function returns a `StateLoadResult`, a named tuple with two fields:\n",
    "    - **missing_keys**: A list of keys that are present in the module but missing in the provided `state_dict`.\n",
    "    - **unexpected_keys**: A list of keys found in the `state_dict` that don't correspond to any part of the current module."
   ]
  },
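  {
   "cell_type": "markdown",
   "id": "9d2e4b71f0a8c356",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "The two fields of the result described above can be illustrated with plain dictionaries. The following is a minimal sketch of the key matching, assuming it works by comparing the key sets of the module and the `state_dict`. ``LoadResult`` and ``load_matching`` are hypothetical names used only for illustration. They are not part of BrainPy, and BrainPy's actual implementation may differ.\n",
    "\n",
    "```python\n",
    "from collections import namedtuple\n",
    "\n",
    "# Hypothetical stand-in for the result type described above.\n",
    "LoadResult = namedtuple('LoadResult', ['missing_keys', 'unexpected_keys'])\n",
    "\n",
    "\n",
    "def load_matching(current, state_dict):\n",
    "  # Copy values for the shared keys into `current` and report the rest.\n",
    "  shared = set(current) & set(state_dict)\n",
    "  for key in shared:\n",
    "    current[key] = state_dict[key]\n",
    "  missing = sorted(set(current) - shared)        # in the module, not in state_dict\n",
    "  unexpected = sorted(set(state_dict) - shared)  # in state_dict, not in the module\n",
    "  return LoadResult(missing, unexpected)\n",
    "\n",
    "\n",
    "current = {'V': 0.0, 'spike': False, 'g': 0.0}\n",
    "saved = {'V': -65.0, 'spike': True, 'old_var': 1.0}\n",
    "result = load_matching(current, saved)\n",
    "print(result)\n",
    "print(current)\n",
    "```\n",
    "\n",
    "In this toy case ``g`` exists only in the module and ``old_var`` only in the saved dictionary, so they are reported as missing and unexpected, while ``V`` and ``spike`` are overwritten with the saved values."
   ]
  },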
  {
   "cell_type": "markdown",
   "id": "1417550bc0e4df4e",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## A simple example"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8512796",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Here is a example of model saving and loading in BrainPy using ``bp.checkpoints.save_pytree`` and ``bp.checkpoints.load_pytree`` functions. "
   ]
  },
  {
   "cell_type": "code",
   "id": "8c70c70c785f620c",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-10-06T05:17:17.687249Z",
     "start_time": "2025-10-06T05:17:17.682511Z"
    }
   },
   "source": [
    "bm.set_dt(1.)\n",
    "\n",
    "class SNN(bp.DynamicalSystem):\n",
    "  def __init__(self, num_in, num_rec, num_out):\n",
    "    super().__init__()\n",
    "\n",
    "    # parameters\n",
    "    self.num_in = num_in\n",
    "    self.num_rec = num_rec\n",
    "    self.num_out = num_out\n",
    "\n",
    "    # neuron groups\n",
    "    self.r = bp.dyn.Lif(num_rec, tau=10., V_reset=0., V_rest=0., V_th=1.)\n",
    "    self.o = bp.dyn.Integrator(num_out, tau=5.)\n",
    "\n",
    "    # synapse: i->r\n",
    "    self.i2r = bp.Sequential(\n",
    "        comm=bp.dnn.Linear(num_in, num_rec, W_initializer=bp.init.KaimingNormal(scale=20.)),\n",
    "        syn=bp.dyn.Expon(num_rec, tau=10.),\n",
    "    )\n",
    "\n",
    "    # synapse: r->o\n",
    "    self.r2o = bp.Sequential(\n",
    "        comm=bp.dnn.Linear(num_rec, num_out, W_initializer=bp.init.KaimingNormal(scale=20.)),\n",
    "        syn=bp.dyn.Expon(num_out, tau=10.),\n",
    "    )\n",
    "\n",
    "  def update(self, spike):\n",
    "    return spike >> self.i2r >> self.r >> self.r2o >> self.o"
   ],
   "outputs": [],
   "execution_count": 10
  },
  {
   "cell_type": "code",
   "id": "edbfcc58",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-10-06T05:17:20.242669Z",
     "start_time": "2025-10-06T05:17:17.690454Z"
    }
   },
   "source": [
    "num_in = 100\n",
    "num_rec = 10\n",
    "with bm.training_environment():\n",
    "    # out task is a two label classification task\n",
    "    net = SNN(num_in, num_rec, 2)  \n",
    "\n",
    "\n",
    "# We try to use this simple task to classify a random spiking data into two classes. \n",
    "num_step = 100\n",
    "num_sample = 256\n",
    "freq = 10  # Hz\n",
    "mask = bm.random.rand(num_step, num_sample, num_in)\n",
    "x_data = bm.zeros((num_step, num_sample, num_in))\n",
    "x_data[mask < freq * bm.get_dt() / 1000.] = 1.0\n",
    "y_data = bm.asarray(bm.random.rand(num_sample) < 0.5, dtype=bm.float_)\n",
    "indices = bm.arange(num_step)\n",
    "\n",
    "\n",
    "# training process\n",
    "class Trainer:\n",
    "  def __init__(self, net, opt):\n",
    "    self.net = net\n",
    "    self.opt = opt\n",
    "    opt.register_train_vars(net.train_vars().unique())\n",
    "    self.f_grad = bm.grad(self.f_loss, grad_vars=self.opt.vars_to_train, return_value=True)\n",
    "  \n",
    "  @bm.cls_jit(inline=True)\n",
    "  def f_loss(self):\n",
    "    self.net.reset(num_sample)\n",
    "    outs = bm.for_loop(self.net.step_run, (indices, x_data))\n",
    "    return bp.losses.cross_entropy_loss(bm.max(outs, axis=0), y_data)\n",
    "\n",
    "  @bm.cls_jit\n",
    "  def f_train(self):\n",
    "    grads, loss = self.f_grad()\n",
    "    self.opt.update(grads)\n",
    "    return loss\n",
    "\n",
    "\n",
    "trainer = Trainer(net=net, opt=bp.optim.Adam(lr=4e-3))\n",
    "\n",
    "loss = np.inf\n",
    "for i in range(10):\n",
    "  l = trainer.f_train()\n",
    "  if l < loss:\n",
    "    loss = l\n",
    "    states = {'net': bp.save_state(net), # save the state dict of the network in the checkpoint\n",
    "              'epoch_i': i,\n",
    "              'train_loss': loss}\n",
    "    bp.checkpoints.save_pytree('snn.bp', states, verbose=False) # save the checkpoint\n",
    "    print(f'Epoch {i}, loss {loss}')"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0, loss 0.7809630632400513\n",
      "Epoch 1, loss 0.7564582824707031\n",
      "Epoch 2, loss 0.7331098914146423\n",
      "Epoch 3, loss 0.7209206819534302\n",
      "Epoch 4, loss 0.713200569152832\n",
      "Epoch 5, loss 0.6973298788070679\n",
      "Epoch 6, loss 0.689375638961792\n",
      "Epoch 8, loss 0.6860182285308838\n",
      "Epoch 9, loss 0.6855737566947937\n"
     ]
    }
   ],
   "execution_count": 11
  },
  {
   "cell_type": "code",
   "id": "621ac319",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-10-06T05:17:23.600146Z",
     "start_time": "2025-10-06T05:17:23.546132Z"
    }
   },
   "source": [
    "# model loading\n",
    "state_dict = bp.checkpoints.load_pytree('snn.bp') # load the state dict\n",
    "bp.load_state(net, state_dict['net']) # unpack the state dict and load it into the network"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading checkpoint from snn.bp\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "StateLoadResult(missing_keys=[], unexpected_keys=[])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 12
  },
  {
   "cell_type": "markdown",
   "id": "a34074f2",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "```{note}\n",
    "By default, the model variables are retrived by the relative path. Relative path retrival usually results in duplicate variables in the returned ArrayCollector. Therefore, there will always be missing keys when loading the variables. \n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "422be59e",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Custom saving and loading"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8aef7f2d",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "You can make your own saving and loading functions easily.\n",
    "\n",
    "For customizing the saving and loading, users can overwrite ``save_state`` and ``load_state`` functions.\n",
    "\n",
    "Here is an example to customize:\n",
    "```python\n",
    "class YourClass(bp.DynamicSystem):\n",
    "  def __init__(self):\n",
    "    self.a = 1\n",
    "    self.b = bm.random.rand(10)\n",
    "    self.c = bm.Variable(bm.random.rand(3))\n",
    "    self.d = bm.var_list([bm.Variable(bm.random.rand(3)),\n",
    "                         bm.Variable(bm.random.rand(3))])\n",
    "\n",
    "  def save_state(self) -> dict:\n",
    "    state_dict = {'a': self.a,\n",
    "            'b': self.b,\n",
    "            'c': self.c}\n",
    "    for i, elem in enumerate(self.d):\n",
    "      state_dict[f'd{i}'] = elem.value\n",
    "\n",
    "    return state_dict\n",
    "\n",
    "  def load_state(self, state_dict):\n",
    "    self.a = state_dict['a']\n",
    "    self.b = bm.asarray(state_dict['b'])\n",
    "    self.c = bm.asarray(state_dict['c'])\n",
    "\n",
    "    for i in range(len(self.d)):\n",
    "      self.d[i].value = bm.asarray(state_dict[f'd{i}'])\n",
    "```\n",
    "\n",
    "\n",
    "- ``save_state(self)`` function saves the state of the object's variables and returns a dictionary where the keys are the names of the variables and the values are the variables' contents.\n",
    "\n",
    "- ``load_state(self, state_dict: Dict)`` function loads the state of the object's variables from a provided dictionary (``state_dict``). \n",
    "At firstly it gets the current variables of the object.\n",
    "Then, it determines the intersection of keys from the provided state_dict and the object's variables.\n",
    "For each intersecting key, it updates the value of the object's variable with the value from state_dict.\n",
    "Finally, returns A tuple containing two lists:\n",
    "  - ``unexpected_keys``: Keys in state_dict that were not found in the object's variables.\n",
    "  - ``missing_keys``: Keys that are in the object's variables but were not found in state_dict."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "brainpy",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
